{

  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "X2Fj4S3r0p1A"
      },
      "source": [
        "##### Copyright 2019 Google LLC.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "Okg-R95R1CaX",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x09idYC2vsZQ",
        "colab_type": "text"
      },
      "source": [
        "# Mesh Segmentation using Feature Steered Graph Convolutions\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/mesh_segmentation_demo.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/mesh_segmentation_demo.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "t4v1coMcWtiJ"
      },
      "source": [
        "# Mesh Segmentation using Feature Steered Graph Convolutions\n",
        "\n",
        "Segmenting a mesh to its semantic parts is an important problem for 3D shape\n",
        "understanding. This colab demonstrates how to build a semantic mesh segmentation\n",
        "model for deformable shapes using graph convolution layers defined in\n",
        "[TensorFlow Graphics](https://github.com/tensorflow/graphics).\n",
        "\n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/mesh_segmentation_demo.png)\n",
        "\n",
        "This notebook covers the following key topics:\n",
        "* How to use graph-convolutional layers to define a CNN for mesh segmentation.\n",
        "* How to setup a data pipeline to represent mesh connectivity with SparseTensors.\n",
        "\n",
        "Note: The easiest way to use this tutorial is as a Colab notebook, which allows\n",
        "you to dive in with no setup.\n",
        "\n",
        "### Image Convolutions vs Graph Convolutions\n",
        "\n",
        "Images are represented by uniform grids of pixels. Running convolutions on\n",
        "uniform grids is a well understood process and is at the core of a significant\n",
        "amount of products and academic publications.\n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/cat_image_convolutions.png)\n",
        "\n",
        "However, things become a bit more complicated when dealing with three\n",
        "dimensional objects such as meshes or point clouds since these are not defined\n",
        "on regular grids. A convolution operation for meshes or point clouds must\n",
        "operate on irregular data structures. This makes convolutional neural\n",
        "networks based on them harder to implement.\n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/cat_mesh_convolutions.png)\n",
        "\n",
        "Any general mesh can be denoted as a graph that is not constrained to a regular grid. Many graph-convolutional operators have been published in\n",
        "the recent years. In this demo we use the method described in\n",
        "[Feature Steered Graph Convolutions](https://arxiv.org/abs/1706.05206). Similar\n",
        "to its image counterpart, this basic building block can be used do solve a\n",
        "plethora of problems. This Colab focuses on segmenting deformable meshes of\n",
        "human bodies into parts (e.g. head, right foot, etc.)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PNQ29y8Q4_cH"
      },
      "source": [
        "## Setup & Imports\n",
        "\n",
        "To run this Colab optimally, please update the runtime type to use a GPU\n",
        "hardware accelerator. - click on the 'Runtime' menu, then 'Change runtime type':\n",
        "\n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/non_rigid_deformation/change_runtime.jpg)\n",
        "\n",
        "-   finally, set the 'Hardware accelerator' to 'GPU'.\n",
        "\n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/gpu_runtime.png)\n",
        "\n",
        "If TensorFlow Graphics is not installed on your system, the following cell will\n",
        "install the TensorFlow Graphics package for you."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "26AvKq8MJRGl",
        "colab": {}
      },
      "source": [
        "!pip install tensorflow_graphics"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UkPKOuyJKuKM"
      },
      "source": [
        "Now that TensorFlow Graphics and dependencies are installed, let's import everything needed to run the demos contained in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "KlBviBxue7n0",
        "colab": {}
      },
      "source": [
        "import glob\n",
        "import os\n",
        "import tensorflow.compat.v1 as tf\n",
        "tf.disable_v2_behavior()\n",
        "\n",
        "from tensorflow_graphics.nn.layer import graph_convolution as graph_conv\n",
        "from tensorflow_graphics.notebooks import mesh_segmentation_dataio as dataio\n",
        "from tensorflow_graphics.notebooks import mesh_viewer"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AGaDtH49dlJb"
      },
      "source": [
        "Note this notebook works best in Graph mode."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6Gh-ZSwXnB-5"
      },
      "source": [
        "### Fetch model files and data\n",
        "\n",
        "For convenience, we provide a pre-trained model. Let's now download a pre-trained model checkpoint and the test data. The meshes are generated using Unity Multipurpose Avatar system [UMA](https://assetstore.unity.com/packages/3d/characters/uma-2-unity-multipurpose-avatar-35611)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_ZkB3iIcvvzJ",
        "colab": {}
      },
      "source": [
        "path_to_model_zip = tf.keras.utils.get_file(\n",
        "    'model.zip',\n",
        "    origin='https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/model.zip',\n",
        "    extract=True)\n",
        "\n",
        "path_to_data_zip = tf.keras.utils.get_file(\n",
        "    'data.zip',\n",
        "    origin='https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/data.zip',\n",
        "    extract=True)\n",
        "\n",
        "local_model_dir = os.path.join(os.path.dirname(path_to_model_zip), 'model')\n",
        "test_data_files = [\n",
        "    os.path.join(\n",
        "        os.path.dirname(path_to_data_zip),\n",
        "        'data/Dancer_test_sequence.tfrecords')\n",
        "]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dmh4b6VKcATt"
      },
      "source": [
        "## Load and visualize test data\n",
        "\n",
        "For graph convolutions, we need a *weighted adjacency matrix* denoting the mesh\n",
        "connectivity. Feature-steered graph convolutions expect self-edges in the mesh\n",
        "connectivity for each vertex, i.e. the diagonal of the weighted adjacency matrix\n",
        "should be non-zero. This matrix is defined as:\n",
        "```\n",
        "A[i, j] = w[i,j] if vertex i and vertex j share an edge,\n",
        "A[i, i] = w[i,i] for each vertex i,\n",
        "A[i, j] = 0 otherwise.\n",
        "where, w[i, j] = 1/(degree(vertex i)), and sum(j)(w[i,j]) = 1\n",
        "```\n",
        "Here degree(vertex i) is the number of edges incident on a vertex (including the\n",
        "self-edge). This weighted adjacency matrix is stored as a SparseTensor.\n",
        "\n",
        "We will load the test meshes from the test [tf.data.TFRecordDataset](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset)\n",
        "downloaded above. Each mesh is stored as a\n",
        "[tf.Example](https://www.tensorflow.org/api_docs/python/tf/train/Example), with\n",
        "the following fields:\n",
        "\n",
        "*   'num_vertices': Number of vertices in each mesh.\n",
        "*   'num_triangles': Number of triangles in each mesh.\n",
        "*   'vertices': A [V, 3] float tensor of vertex positions.\n",
        "*   'triangles': A [T, 3] integer tensor of vertex indices for each triangle.\n",
        "*   'labels': A [V] integer tensor with segmentation class label for each\n",
        "    vertex.\n",
        "\n",
        "where 'V' is number of vertices and 'T' is number of triangles in the mesh. As\n",
        "each mesh may have a varying number of vertices and faces (and the corresponding\n",
        "connectivity matrix), we pad the data tensors with '0's in each batch.\n",
        "\n",
        "For details on the dataset pipeline implementation, take a look at\n",
        "mesh_segmentation_dataio.py.\n",
        "\n",
        "Let's try to load a batch from the test TFRecordDataset, and visualize the first\n",
        "mesh with each vertex colored by the part label."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LZM02o0pEny6",
        "colab": {}
      },
      "source": [
        "test_io_params = {\n",
        "    'is_training': False,\n",
        "    'sloppy': False,\n",
        "    'shuffle': True,\n",
        "}\n",
        "test_tfrecords = test_data_files\n",
        "\n",
        "\n",
        "input_graph = tf.Graph()\n",
        "with input_graph.as_default():\n",
        "  mesh_load_op = dataio.create_input_from_dataset(\n",
        "      dataio.create_dataset_from_tfrecords, test_tfrecords, test_io_params)\n",
        "  with tf.Session() as sess:\n",
        "    test_mesh_data, test_labels = sess.run(mesh_load_op)\n",
        "\n",
        "input_mesh_data = {\n",
        "    'vertices': test_mesh_data['vertices'][0, ...],\n",
        "    'faces': test_mesh_data['triangles'][0, ...],\n",
        "    'vertex_colors': mesh_viewer.SEGMENTATION_COLORMAP[test_labels[0, ...]],\n",
        "}\n",
        "input_viewer = mesh_viewer.Viewer(input_mesh_data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aqV6vkCkWB7J"
      },
      "source": [
        "## Model Definition\n",
        "\n",
        "Given a mesh with V vertices and D-dimensional per-vertex input features (e.g.\n",
        "vertex position, normal), we would like to create a network capable of\n",
        "classifying each vertex to a part label. Let's first create a mesh encoder that\n",
        "encodes each vertex in the mesh into C-dimensional logits, where C is the number\n",
        "of parts. First we use 1x1 convolutions to change input feature dimensions,\n",
        "followed by a sequence of feature steered graph convolutions and ReLU\n",
        "non-linearities, and finally 1x1 convolutions to logits, which are used for\n",
        "computing softmax cross entropy as described below.\n",
        "\n",
        "Note that this model does not use any form of pooling, which is outside the scope of this notebook.\n",
        "\n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/mesh_segmentation_model_def.png)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "fQVeuGazM0LK",
        "colab": {}
      },
      "source": [
        "MODEL_PARAMS = {\n",
        "    'num_filters': 8,\n",
        "    'num_classes': 16,\n",
        "    'encoder_filter_dims': [32, 64, 128],\n",
        "}\n",
        "\n",
        "\n",
        "def mesh_encoder(batch_mesh_data, num_filters, output_dim, conv_layer_dims):\n",
        "  \"\"\"A mesh encoder using feature steered graph convolutions.\n",
        "\n",
        "    The shorthands used below are\n",
        "      `B`: Batch size.\n",
        "      `V`: The maximum number of vertices over all meshes in the batch.\n",
        "      `D`: The number of dimensions of input vertex features, D=3 if vertex\n",
        "        positions are used as features.\n",
        "\n",
        "  Args:\n",
        "    batch_mesh_data: A mesh_data dict with following keys\n",
        "      'vertices': A [B, V, D] `float32` tensor of vertex features, possibly\n",
        "        0-padded.\n",
        "      'neighbors': A [B, V, V] `float32` sparse tensor of edge weights.\n",
        "      'num_vertices': A [B] `int32` tensor of number of vertices per mesh.\n",
        "    num_filters: The number of weight matrices to be used in feature steered\n",
        "      graph conv.\n",
        "    output_dim: A dimension of output per vertex features.\n",
        "    conv_layer_dims: A list of dimensions used in graph convolution layers.\n",
        "\n",
        "  Returns:\n",
        "    vertex_features: A [B, V, output_dim] `float32` tensor of per vertex\n",
        "      features.\n",
        "  \"\"\"\n",
        "  batch_vertices = batch_mesh_data['vertices']\n",
        "\n",
        "  # Linear: N x D --> N x 16.\n",
        "  vertex_features = tf.keras.layers.Conv1D(16, 1, name='lin16')(batch_vertices)\n",
        "\n",
        "  # graph convolution layers\n",
        "  for dim in conv_layer_dims:\n",
        "    with tf.variable_scope('conv_%d' % dim):\n",
        "      vertex_features = graph_conv.feature_steered_convolution_layer(\n",
        "          vertex_features,\n",
        "          batch_mesh_data['neighbors'],\n",
        "          batch_mesh_data['num_vertices'],\n",
        "          num_weight_matrices=num_filters,\n",
        "          num_output_channels=dim)\n",
        "    vertex_features = tf.nn.relu(vertex_features)\n",
        "\n",
        "  # Linear: N x 128 --> N x 256.\n",
        "  vertex_features = tf.keras.layers.Conv1D(\n",
        "      256, 1, name='lin256')(\n",
        "          vertex_features)\n",
        "  vertex_features = tf.nn.relu(vertex_features)\n",
        "\n",
        "  # Linear: N x 256 --> N x output_dim.\n",
        "  vertex_features = tf.keras.layers.Conv1D(\n",
        "      output_dim, 1, name='lin_output')(\n",
        "          vertex_features)\n",
        "\n",
        "  return vertex_features"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6c2pz4r_79F_"
      },
      "source": [
        "Given a mesh encoder, let's define a model_fn for a custom\n",
        "[tf.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)\n",
        "for vertex classification using softmax cross entropy loss. A tf.Estimator model_fn returns the ops necessary to perform training, evaluation, or predictions given inputs and a number of other parameters. Recall that the\n",
        "vertex tensor may be zero-padded (see Dataset Pipeline above), hence we must mask out the contribution from the padded values."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WE-cuv0i78ak",
        "colab": {}
      },
      "source": [
        "def get_learning_rate(params):\n",
        "  \"\"\"Returns a decaying learning rate.\"\"\"\n",
        "  global_step = tf.train.get_or_create_global_step()\n",
        "  learning_rate = tf.train.exponential_decay(\n",
        "      params['init_learning_rate'],\n",
        "      global_step,\n",
        "      params['lr_decay_steps'],\n",
        "      params['lr_decay_rate'])\n",
        "  return learning_rate\n",
        "\n",
        "def model_fn(features, labels, mode, params):\n",
        "  \"\"\"Returns a mesh segmentation model_fn for use with tf.Estimator.\"\"\"\n",
        "  logits = mesh_encoder(features, params['num_filters'], params['num_classes'],\n",
        "                        params['encoder_filter_dims'])\n",
        "  predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)\n",
        "  outputs = {\n",
        "      'vertices': features['vertices'],\n",
        "      'triangles': features['triangles'],\n",
        "      'num_vertices': features['num_vertices'],\n",
        "      'num_triangles': features['num_triangles'],\n",
        "      'predictions': predictions,\n",
        "  }\n",
        "  # For predictions, return the outputs.\n",
        "  if mode == tf.estimator.ModeKeys.PREDICT:\n",
        "    outputs['labels'] = features['labels']\n",
        "    return tf.estimator.EstimatorSpec(mode=mode, predictions=outputs)\n",
        "  # Loss\n",
        "  # Weight the losses by masking out padded vertices/labels.\n",
        "  vertex_ragged_sizes = features['num_vertices']\n",
        "  mask = tf.sequence_mask(vertex_ragged_sizes, tf.shape(labels)[-1])\n",
        "  loss_weights = tf.cast(mask, dtype=tf.float32)\n",
        "  loss = tf.losses.sparse_softmax_cross_entropy(\n",
        "      logits=logits, labels=labels, weights=loss_weights)\n",
        "  # For training, build the optimizer.\n",
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n",
        "    optimizer = tf.train.AdamOptimizer(\n",
        "        learning_rate=get_learning_rate(params),\n",
        "        beta1=params['beta'],\n",
        "        epsilon=params['adam_epsilon'])\n",
        "    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
        "    with tf.control_dependencies(update_ops):\n",
        "      train_op = optimizer.minimize(\n",
        "          loss=loss, global_step=tf.train.get_global_step())\n",
        "    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)\n",
        "\n",
        "  # For eval, return eval metrics.\n",
        "  eval_ops = {\n",
        "      'mean_loss':\n",
        "          tf.metrics.mean(loss),\n",
        "      'accuracy':\n",
        "          tf.metrics.accuracy(\n",
        "              labels=labels, predictions=predictions, weights=loss_weights)\n",
        "  }\n",
        "  return tf.estimator.EstimatorSpec(\n",
        "      mode=mode, loss=loss, eval_metric_ops=eval_ops)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "94FICCro_dLV"
      },
      "source": [
        "## Test model & visualize results\n",
        "\n",
        "Now that we have defined the model, let's load the weights from the trained model downloaded above and use tf.Estimator.predict to predict the part labels for meshes in the test dataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Olj5zIkg72FK",
        "colab": {}
      },
      "source": [
        "test_io_params = {\n",
        "    'is_training': False,\n",
        "    'sloppy': False,\n",
        "    'shuffle': True,\n",
        "    'repeat': False\n",
        "}\n",
        "test_tfrecords = test_data_files\n",
        "\n",
        "def predict_fn():\n",
        "  return dataio.create_input_from_dataset(dataio.create_dataset_from_tfrecords,\n",
        "                                          test_tfrecords,\n",
        "                                          test_io_params)\n",
        "\n",
        "\n",
        "estimator = tf.estimator.Estimator(model_fn=model_fn,\n",
        "                                   model_dir=local_model_dir,\n",
        "                                   params=MODEL_PARAMS)\n",
        "test_predictions = estimator.predict(input_fn=predict_fn)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IO1VmbL087xf"
      },
      "source": [
        "Run the following cell repeatedly to cycle through the meshes in the test sequence. The left view shows the input mesh, and the right view shows the predicted part labels."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "xuoVe70D5PAF",
        "colab": {}
      },
      "source": [
        "prediction = next(test_predictions)\n",
        "input_mesh_data = {\n",
        "    'vertices': prediction['vertices'],\n",
        "    'faces': prediction['triangles'],\n",
        "}\n",
        "predicted_mesh_data = {\n",
        "    'vertices': prediction['vertices'],\n",
        "    'faces': prediction['triangles'],\n",
        "    'vertex_colors': mesh_viewer.SEGMENTATION_COLORMAP[prediction['predictions']],\n",
        "}\n",
        "\n",
        "input_viewer = mesh_viewer.Viewer(input_mesh_data)\n",
        "prediction_viewer = mesh_viewer.Viewer(predicted_mesh_data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QE8Meyi5GKWA"
      },
      "source": [
        "## Train the model from scratch\n",
        "\n",
        "Now let's train the mesh segmentation model from scratch. First we will download the train dataset files, and use tf.Estimator.train_and_evaluate to train a model.\n",
        "\n",
        "Note: Training code is provided inside colab for demonstration, and may be slow. For optimal performance, consider running the training process as a command line process, and a tensorboard process to track."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "S6vxJ8t5HRcS",
        "colab": {}
      },
      "source": [
        "path_to_train_data_zip = tf.keras.utils.get_file(\n",
        "    'train_data.zip',\n",
        "    origin='https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/train_data.zip',\n",
        "    extract=True)\n",
        "\n",
        "train_data_files = glob.glob(\n",
        "    os.path.join(os.path.dirname(path_to_train_data_zip), '*train*.tfrecords'))\n",
        "\n",
        "retrain_model_dir = os.path.join(local_model_dir, 'retrain')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "inNSvIN-HcBg",
        "colab": {}
      },
      "source": [
        "train_io_params = {\n",
        "    'batch_size': 8,\n",
        "    'parallel_threads': 8,\n",
        "    'is_training': True,\n",
        "    'shuffle': True,\n",
        "    'sloppy': True,\n",
        "}\n",
        "\n",
        "eval_io_params = {\n",
        "    'batch_size': 8,\n",
        "    'parallel_threads': 8,\n",
        "    'is_training': False,\n",
        "    'shuffle': False\n",
        "}\n",
        "\n",
        "\n",
        "def train_fn():\n",
        "  return dataio.create_input_from_dataset(dataio.create_dataset_from_tfrecords,\n",
        "                                          train_data_files, train_io_params)\n",
        "\n",
        "\n",
        "def eval_fn():\n",
        "  return dataio.create_input_from_dataset(dataio.create_dataset_from_tfrecords,\n",
        "                                          test_data_files, eval_io_params)\n",
        "\n",
        "\n",
        "train_params = {\n",
        "    'beta': 0.9,\n",
        "    'adam_epsilon': 1e-8,\n",
        "    'init_learning_rate': 0.001,\n",
        "    'lr_decay_steps': 10000,\n",
        "    'lr_decay_rate': 0.95,\n",
        "}\n",
        "\n",
        "train_params.update(MODEL_PARAMS)\n",
        "\n",
        "checkpoint_delay = 120  # Checkpoint every 2 minutes.\n",
        "max_steps = 100000  # Number of training steps.\n",
        "\n",
        "config = tf.estimator.RunConfig(\n",
        "    log_step_count_steps=1,\n",
        "    save_checkpoints_secs=checkpoint_delay,\n",
        "    keep_checkpoint_max=3)\n",
        "\n",
        "classifier = tf.estimator.Estimator(\n",
        "    model_fn=model_fn,\n",
        "    model_dir=retrain_model_dir,\n",
        "    config=config,\n",
        "    params=train_params)\n",
        "train_spec = tf.estimator.TrainSpec(input_fn=train_fn, max_steps=max_steps)\n",
        "eval_spec = tf.estimator.EvalSpec(\n",
        "    input_fn=eval_fn,\n",
        "    steps=None,\n",
        "    start_delay_secs=2 * checkpoint_delay,\n",
        "    throttle_secs=checkpoint_delay)\n",
        "\n",
        "print('Start training & eval.')\n",
        "tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)\n",
        "print('Train and eval done.')"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "mesh_segmentation_demo.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
